{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Number of model parameters: 3.38 million\n",
      "Running on: cuda\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:45<00:00, 17.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/10], Loss: 2.3065, LR: 0.0010\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:45<00:00, 17.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [2/10], Loss: 2.3067, LR: 0.0009\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:46<00:00, 16.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [3/10], Loss: 2.2931, LR: 0.0008\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:46<00:00, 16.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [4/10], Loss: 2.3023, LR: 0.0007\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:46<00:00, 16.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [5/10], Loss: 2.3356, LR: 0.0005\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:45<00:00, 17.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [6/10], Loss: 2.3051, LR: 0.0003\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:45<00:00, 17.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [7/10], Loss: 2.2979, LR: 0.0002\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:45<00:00, 17.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [8/10], Loss: 2.2996, LR: 0.0001\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:45<00:00, 17.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [9/10], Loss: 2.3053, LR: 0.0000\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 782/782 [00:45<00:00, 17.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [10/10], Loss: 2.3023, LR: 0.0000\n",
      "Accuracy of the model on the 10000 test images: 10.0%\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# Define the ViT model\n",
    "class ViT(nn.Module):\n",
    "    \"\"\"Minimal Vision Transformer classifier for square RGB images.\n",
    "\n",
    "    A strided convolution cuts the image into non-overlapping patches and\n",
    "    projects each one to a ``dim``-wide token. A learnable class token is\n",
    "    prepended, learnable position embeddings are added, and the sequence is\n",
    "    passed through a 12-layer Transformer encoder. The encoder output at the\n",
    "    class-token position feeds a linear classification head.\n",
    "\n",
    "    Args:\n",
    "        image_size: Height/width in pixels of the (square) input images.\n",
    "        patch_size: Side length of each square patch; must divide ``image_size``.\n",
    "        num_classes: Number of output logits.\n",
    "        dim: Width of every token embedding (must be divisible by the 4 heads).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_size, patch_size, num_classes, dim):\n",
    "        super().__init__()\n",
    "        assert image_size % patch_size == 0, \"image size must be divisible by patch size\"\n",
    "        num_patches = (image_size // patch_size) ** 2\n",
    "        self.patch_size = patch_size\n",
    "        self.num_patches = num_patches\n",
    "        # A convolution with kernel == stride is a linear projection of each patch.\n",
    "        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)\n",
    "        # Learnable class token; forward() reads the prediction from its position.\n",
    "        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))\n",
    "        # Learnable position embeddings (one per patch plus the class token);\n",
    "        # without them self-attention cannot tell patch positions apart.\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim) * 0.02)\n",
    "        # batch_first=True because forward() produces (batch, seq, dim) tokens.\n",
    "        # The default layout is (seq, batch, dim), which would make attention\n",
    "        # run across the samples of a batch instead of across patches.\n",
    "        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)\n",
    "        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=12)\n",
    "        self.classification_head = nn.Linear(dim, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map images of shape (batch, 3, H, W) to logits of shape (batch, num_classes).\"\"\"\n",
    "        # (batch, dim, H/p, W/p) -> (batch, num_patches, dim)\n",
    "        patches = self.patch_embedding(x).flatten(2).transpose(1, 2)\n",
    "        cls_tokens = self.cls_token.expand(patches.size(0), -1, -1)\n",
    "        tokens = torch.cat([cls_tokens, patches], dim=1) + self.pos_embedding\n",
    "        embeddings = self.transformer(tokens)\n",
    "        cls_embedding = embeddings[:, 0]\n",
    "        return self.classification_head(cls_embedding)\n",
    "\n",
    "# Define the training parameters\n",
    "SEED = 0\n",
    "# Seed PyTorch's global RNG so weight initialization and data shuffling are\n",
    "# repeatable across Restart & Run All.\n",
    "torch.manual_seed(SEED)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "batch_size = 64\n",
    "num_epochs = 10\n",
    "learning_rate = 0.001\n",
    "\n",
    "# Load the CIFAR10 training split with random crop/flip augmentation\n",
    "transform = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
    "])\n",
    "train_dataset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=True, transform=transform)\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# Initialize the model, optimizer, and loss function\n",
    "model = ViT(image_size=32, patch_size=4, num_classes=10, dim=64).to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Define the learning rate scheduler: cosine annealing from learning_rate\n",
    "# over T_max=num_epochs scheduler steps.\n",
    "lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)\n",
    "\n",
    "print(f\"Number of model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f} million\")\n",
    "print(f\"Running on: {device}\")\n",
    "\n",
    "# Load the CIFAR10 test split for evaluation.\n",
    "# Only tensor conversion and normalization are applied here: the random\n",
    "# crop/flip augmentation used for training would make test accuracy vary\n",
    "# from run to run.\n",
    "test_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
    "])\n",
    "test_dataset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=True, transform=test_transform)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "def evaluate(model, test_loader, test_dataset):\n",
    "    \"\"\"Print and return the top-1 accuracy of ``model`` over ``test_loader``.\n",
    "\n",
    "    Args:\n",
    "        model: Module mapping a batch of images to per-class scores.\n",
    "        test_loader: Iterable yielding ``(images, labels)`` batches.\n",
    "        test_dataset: Sized collection; only its length is used, in the\n",
    "            printed message.\n",
    "\n",
    "    Returns:\n",
    "        Accuracy as a percentage in [0, 100]; 0.0 if the loader yields no samples.\n",
    "\n",
    "    Side effects:\n",
    "        Leaves ``model`` in eval mode.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # Evaluate on whichever device the model lives on instead of relying on a\n",
    "    # notebook-level global; parameter-free models fall back to the CPU.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    eval_device = first_param.device if first_param is not None else torch.device(\"cpu\")\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in test_loader:\n",
    "            images = images.to(eval_device)\n",
    "            labels = labels.to(eval_device)\n",
    "            predicted = model(images).argmax(dim=1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    # Guard against an empty loader, which would otherwise divide by zero.\n",
    "    accuracy = 100 * correct / total if total else 0.0\n",
    "    print(f\"Accuracy of the model on the {len(test_dataset)} test images: {accuracy}%\")\n",
    "    return accuracy\n",
    "\n",
    "\n",
    "# Train the model\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()  # evaluate() leaves the model in eval mode, so re-enable training behaviour\n",
    "    running_loss = 0.0\n",
    "    num_samples = 0\n",
    "    # Learning rate in effect for this epoch; read before the scheduler advances it\n",
    "    current_lr = optimizer.param_groups[0]['lr']\n",
    "    for images, labels in tqdm(train_loader):\n",
    "        # Move the images and labels to the device\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Zero the gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Forward pass\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Accumulate a sample-weighted sum so the epoch mean stays exact even\n",
    "        # when the final batch is smaller than batch_size\n",
    "        running_loss += loss.item() * labels.size(0)\n",
    "        num_samples += labels.size(0)\n",
    "\n",
    "    # Update the learning rate once per epoch\n",
    "    lr_scheduler.step()\n",
    "    # Print the mean training loss of the epoch, then the test accuracy\n",
    "    epoch_loss = running_loss / max(num_samples, 1)\n",
    "    print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, LR: {current_lr:.4f}\")\n",
    "    evaluate(model, test_loader, test_dataset)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "diffusion",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
